import copy
import random
import torch

from torch.utils.data import Dataset
from .cache import getCt, getCtRawCandidate
from ..aug import getCtAugmentedCandidate
from utils.candidate import getCandidateInfoList
from utils.logconf import logging

# Logger for this module, with its own level set to DEBUG. Note that `logging`
# is imported from utils.logconf, presumably the stdlib module re-exported
# after project-wide configuration — confirm.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


class LunaDataset(Dataset):
    """CT candidate-crop dataset for nodule vs. non-nodule classification.

    Candidates are filtered by series (optionally down to a single
    ``series_uid`` and/or one side of a train/validation split), ordered
    according to ``sortby_str``, and partitioned into negative/positive and
    benign/malignant lists.

    Args:
        val_stride: If > 0, every ``val_stride``-th series (in sorted uid
            order) forms the validation split; the remaining series form the
            training split.
        isValSet_bool: If truthy, keep only the validation series (requires
            ``val_stride > 0``); otherwise, when ``val_stride > 0``, drop them.
        series_uid: If truthy, restrict the dataset to this single series.
        sortby_str: ``"random"`` shuffles, ``"series_uid"`` sorts by
            ``(series_uid, center_xyz)``, ``"label_and_size"`` keeps the
            incoming order.
        ratio_int: If non-zero, ``__getitem__`` yields one positive followed
            by ``ratio_int`` negatives, cycling through both lists, and
            ``__len__`` reports a fixed 50000.
        augmentation_dict: If truthy, crops are produced by
            ``getCtAugmentedCandidate`` with these settings.
        candidateInfo_list: Explicit candidates to use instead of
            ``getCandidateInfoList()``; also sets ``use_cache`` to False.
            A shallow copy is taken, so the caller's list is never reordered.

    Raises:
        ValueError: If ``sortby_str`` is not a recognised value.
        AssertionError: If a validation split is requested without
            ``val_stride > 0``, or the requested split leaves no series.
    """

    def __init__(
        self,
        val_stride=0,
        isValSet_bool=None,
        series_uid=None,
        sortby_str="random",
        ratio_int=0,
        augmentation_dict=None,
        candidateInfo_list=None,
    ):
        self.ratio_int = ratio_int
        self.augmentation_dict = augmentation_dict

        # Compare against None rather than truthiness: an explicitly supplied
        # empty list must not silently fall back to the full cached list.
        if candidateInfo_list is not None:
            self.candidateInfo_list = copy.copy(candidateInfo_list)
            self.use_cache = False
        else:
            # Copy so that filtering/sorting self.candidateInfo_list below does
            # not alter the list object returned by getCandidateInfoList().
            self.candidateInfo_list = copy.copy(getCandidateInfoList())
            self.use_cache = True

        # A truthy series_uid restricts the instance to that one CT scan, which
        # is handy for visualising or debugging a single problematic series.
        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(
                {candidateInfo_tup.series_uid for candidateInfo_tup in self.candidateInfo_list}
            )

        # Deterministic split by series: every val_stride-th series (in sorted
        # order) is validation, the rest is training, so no scan is in both.
        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list

        series_set = set(self.series_list)
        self.candidateInfo_list = [
            x for x in self.candidateInfo_list if x.series_uid in series_set
        ]

        if sortby_str == "random":
            random.shuffle(self.candidateInfo_list)
        elif sortby_str == "series_uid":
            self.candidateInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
        elif sortby_str == "label_and_size":
            # Keep the incoming order. Presumably getCandidateInfoList()
            # already returns candidates ordered by label and size — confirm;
            # an explicitly passed candidateInfo_list is used as given.
            pass
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError("Unknown sort: " + repr(sortby_str))

        self.neg_list = [nt for nt in self.candidateInfo_list if not nt.isNodule_bool]
        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
        self.ben_list = [nt for nt in self.pos_list if not nt.isMal_bool]
        self.mal_list = [nt for nt in self.pos_list if nt.isMal_bool]

        log.info(
            "{!r}: {} {} samples, {} neg, {} pos, {} ratio".format(
                self,
                len(self.candidateInfo_list),
                "validation" if isValSet_bool else "training",
                len(self.neg_list),
                len(self.pos_list),
                "{}:1".format(self.ratio_int) if self.ratio_int else "unbalanced",
            )
        )

    def shuffleSamples(self):
        """Reshuffle all candidate lists in place; a no-op unless ``ratio_int`` is set."""
        if self.ratio_int:
            random.shuffle(self.candidateInfo_list)
            random.shuffle(self.neg_list)
            random.shuffle(self.pos_list)
            random.shuffle(self.ben_list)
            random.shuffle(self.mal_list)

    def __len__(self):
        """Return 50000 in balanced (``ratio_int``) mode, else the candidate count."""
        if self.ratio_int:
            return 50000
        else:
            return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        """Return the sample tuple for index ``ndx``.

        In balanced mode, indices come in groups of ``ratio_int + 1``. The
        first index of each group maps to a positive and the rest map to
        consecutive negatives, wrapping around each list. That mode raises
        ZeroDivisionError if the list needed for ``ndx`` is empty.
        """
        if self.ratio_int:
            pos_ndx = ndx // (self.ratio_int + 1)

            if ndx % (self.ratio_int + 1):
                # Subtract the positives already emitted (pos_ndx + 1 of them,
                # counting this group's) to get a running negative index.
                neg_ndx = ndx - 1 - pos_ndx
                neg_ndx %= len(self.neg_list)
                candidateInfo_tup = self.neg_list[neg_ndx]
            else:
                pos_ndx %= len(self.pos_list)
                candidateInfo_tup = self.pos_list[pos_ndx]
        else:
            candidateInfo_tup = self.candidateInfo_list[ndx]

        return self.sampleFromCandidateInfo_tup(
            candidateInfo_tup, candidateInfo_tup.isNodule_bool
        )

    def sampleFromCandidateInfo_tup(self, candidateInfo_tup, label_bool):
        """Load the crop for ``candidateInfo_tup`` and build its sample tuple.

        Args:
            candidateInfo_tup: Candidate record; its ``series_uid`` and
                ``center_xyz`` are read.
            label_bool: Truthy for the positive class.

        Returns:
            ``(candidate_t, label_t, index_t, series_uid, center_irc_t)``.
            ``label_t`` is a two-element long tensor that is one-hot at
            ``index_t`` (0 = negative, 1 = positive), and ``center_irc_t`` is
            ``torch.tensor(center_irc)``.
        """
        # Crop extent handed to the loaders; presumably voxels along
        # (index, row, col) — confirm against the cache helpers.
        width_irc = (32, 48, 48)

        if self.augmentation_dict:
            candidate_t, center_irc = getCtAugmentedCandidate(
                self.augmentation_dict,
                candidateInfo_tup.series_uid,
                candidateInfo_tup.center_xyz,
                width_irc,
                self.use_cache,
            )
        elif self.use_cache:
            candidate_a, center_irc = getCtRawCandidate(
                candidateInfo_tup.series_uid,
                candidateInfo_tup.center_xyz,
                width_irc,
            )
            candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
            # Add a leading dimension (presumably the channel axis).
            candidate_t = candidate_t.unsqueeze(0)
        else:
            ct = getCt(candidateInfo_tup.series_uid)
            candidate_a, center_irc = ct.getRawCandidate(
                candidateInfo_tup.center_xyz,
                width_irc,
            )
            candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
            # Add a leading dimension (presumably the channel axis).
            candidate_t = candidate_t.unsqueeze(0)

        # Two-class one-hot label: slot 0 for negatives, slot 1 for positives.
        index_t = 1 if label_bool else 0
        label_t = torch.tensor([False, False], dtype=torch.long)
        label_t[index_t] = True

        return (
            candidate_t,
            label_t,
            index_t,
            candidateInfo_tup.series_uid,
            torch.tensor(center_irc),
        )


class MalignantLunaDataset(LunaDataset):
    """LunaDataset variant whose label is malignancy (``isMal_bool``).

    In unbalanced mode only nodules (benign, then malignant) are served. In
    balanced mode (``ratio_int`` truthy) non-nodule negatives are mixed in too.
    """

    def __len__(self):
        """Return 100000 in balanced mode, else the number of benign + malignant nodules."""
        if self.ratio_int:
            return 100000
        # Sum the lengths rather than concatenating both lists just to count them.
        return len(self.ben_list) + len(self.mal_list)

    def __getitem__(self, ndx):
        """Return the sample tuple for ``ndx``, labelled by ``isMal_bool``.

        Balanced mode uses a fixed 2:1:1 mix regardless of ``ratio_int``'s
        value: odd indices are malignant, multiples of 4 are benign, and the
        remaining even indices are non-nodule negatives. Each source list is
        cycled with a modulo, so an empty list raises ZeroDivisionError.
        Unbalanced mode indexes the benign list first, then the malignant list.
        """
        if self.ratio_int:
            if ndx % 2 != 0:
                candidateInfo_tup = self.mal_list[(ndx // 2) % len(self.mal_list)]
            elif ndx % 4 == 0:
                candidateInfo_tup = self.ben_list[(ndx // 4) % len(self.ben_list)]
            else:
                candidateInfo_tup = self.neg_list[(ndx // 4) % len(self.neg_list)]
        else:
            if ndx >= len(self.ben_list):
                candidateInfo_tup = self.mal_list[ndx - len(self.ben_list)]
            else:
                candidateInfo_tup = self.ben_list[ndx]

        return self.sampleFromCandidateInfo_tup(
            candidateInfo_tup, candidateInfo_tup.isMal_bool
        )
